{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据准备和增强"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 一、数据准备和增强\n",
    "\n",
    "并不是所有的数据都能一次性的加载进内存中！"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "***二进制数据(protobuf)***\n",
    "\n",
    "TFRecords是一种二进制文件，不需要单独的标签文件，可以更好的利用内存。在tensorflow中快速的复制，移动，读取，存储等。\n",
    "\n",
    "TFRecords的使用主要有如下几个方法步骤：\n",
    "\n",
    "1. 生成TFRecords文件。\n",
    " - 定义Example，数据格式定义。\n",
    " - 然后将数据序列化为string。\n",
    " - 将序列化后的string写入到TFRecords中。\n",
    "2. 解析TFRecords文件。\n",
    "3. 训练过程中调用TFRecords文件。\n",
    "\n",
    "protobuf的样子大概如下：\n",
    "\n",
    "```proto\n",
    "syntax = \"proto3\";\n",
    "message Person {\n",
    "    string name = 1;\n",
    "    int32 id = 2;\n",
    "    repeated string email = 3;\n",
    "}\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: pillow in /Users/wanjun/anaconda/envs/python36/lib/python3.6/site-packages (6.1.0)\n",
      "Requirement already satisfied: tqdm in /Users/wanjun/anaconda/envs/python36/lib/python3.6/site-packages (4.32.1)\n"
     ]
    }
   ],
   "source": [
    "!pip install pillow\n",
    "!pip install tqdm\n",
    "!pip install glob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from tqdm import tqdm\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import sys\n",
    "import glob\n",
    "import tensorflow as tf\n",
    "cwd = os.path.join(os.getcwd(),'mnist','train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_list = glob.glob(os.path.join(cwd,'0')+'/*.png')\n",
    "label_list = np.zeros(len(image_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "#图像标签生成tfrecord函数\n",
    "def image2tfrecord(image_list,label_list,filename):\n",
    "    '''\n",
    "    image_list:image path list\n",
    "    label_list:label list\n",
    "    '''\n",
    "    length=len(image_list)\n",
    "    writer=tf.python_io.TFRecordWriter(filename)\n",
    "    for i in range(length):\n",
    "        if i % 100==0:\n",
    "            ratio=round(i/float(length),4)\n",
    "            sys.stdout.write('ratio:{}\\r'.format(ratio))\n",
    "            sys.stdout.flush()\n",
    "        image=Image.open(image_list[i])\n",
    "        if 'png' in image_list[i][-4:]:\n",
    "            if image.mode=='RGB':\n",
    "                r, g, b = image.split()\n",
    "                image = Image.merge(\"RGB\", (r, g, b))\n",
    "            elif image.mode=='L':\n",
    "                pass\n",
    "            else:\n",
    "                r,g, b, a = image.split()\n",
    "                image = Image.merge(\"RGB\", (r, g, b))\n",
    "        image=image.resize((28,28))\n",
    "        #这个地方就展开了\n",
    "        image_bytes=image.tobytes()\n",
    "        features={}\n",
    "        features['image']=tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes]))\n",
    "        features['label']=tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label_list[i])]))\n",
    "        tf_features=tf.train.Features(feature=features)\n",
    "        tf_example=tf.train.Example(features=tf_features)\n",
    "        tf_serialized=tf_example.SerializeToString()\n",
    "        writer.write(tf_serialized)\n",
    "    writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ratio:0.9961\r"
     ]
    }
   ],
   "source": [
    "image2tfrecord(image_list,label_list,'train0.tfrecords')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/Users/wanjun/Desktop/深度之眼/编程/mnist/train'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=====\n",
      "0\n",
      "=====:0.9961\n",
      "1\n",
      "=====:0.9938\n",
      "2\n",
      "=====:0.9903\n",
      "3\n",
      "=====:0.9949\n",
      "4\n",
      "=====:0.9928\n",
      "5\n",
      "=====:0.9961\n",
      "6\n",
      "=====:0.9971\n",
      "7\n",
      "=====:0.9896\n",
      "8\n",
      "=====:0.9913\n",
      "9\n",
      "ratio:0.9918\r"
     ]
    }
   ],
   "source": [
    "# 生成10个tfrecords\n",
    "for i in range(10):\n",
    "    print('=====')\n",
    "    print(i)\n",
    "    image_list = glob.glob(os.path.join(cwd,'{}'.format(i))+'/*.png')\n",
    "    label_list = np.zeros(len(image_list))+i\n",
    "    image2tfrecord(image_list,label_list,'train_{}.tfrecords'.format(i))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "#解析函数\n",
    "def pares_tf(example_proto):\n",
    "    dics={}\n",
    "    dics['label']=tf.FixedLenFeature((),dtype=tf.int64,default_value=0)\n",
    "    dics['image']=tf.FixedLenFeature((),dtype=tf.string,default_value=\"\")\n",
    "\n",
    "    parsed_example=tf.parse_single_example(serialized=example_proto,features=dics)\n",
    "    image=tf.decode_raw(parsed_example['image'],out_type=tf.uint8)\n",
    "    #image=tf.image.decode_jpeg(parsed_example['image'], channels=1)\n",
    "    #这个地方可以加一些操作\n",
    "    \n",
    "    image=tf.cast(image,tf.float32)/255\n",
    "    image=tf.reshape(image,(28,28,1))\n",
    "    image=pre_process(image)\n",
    "    #标签的操作\n",
    "    label=parsed_example['label']\n",
    "    label=tf.cast(label,tf.int32)\n",
    "    label = tf.one_hot(label,depth=10,on_value=1.0,off_value=0.0)\n",
    "    return image,label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "#dataset\n",
    "def dataset(filenames,batch_size,epochs):\n",
    "    dataset=tf.data.TFRecordDataset(filenames=filenames)\n",
    "    new_dataset=dataset.map(pares_tf)\n",
    "    shuffle_dataset=new_dataset.shuffle(buffer_size=(100000))\n",
    "    batch_dataset=shuffle_dataset.batch(batch_size).repeat(epochs)\n",
    "    batch_dataset=batch_dataset.prefetch(1)\n",
    "    iterator=batch_dataset.make_one_shot_iterator()\n",
    "    next_element=iterator.get_next()\n",
    "    return next_element"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "***关于prefetch的含义***\n",
    "\n",
    "<img src=\"prefetch.png\" style=\"zoom:50%\" >"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 图像预处理\n",
    "def pre_process(images,random_flip_up_down=False,random_flip_left_right=False,random_brightness=True,random_contrast=True,random_saturation=False,random_hue=False):\n",
    "    if random_flip_up_down:\n",
    "        images = tf.image.random_flip_up_down(images)\n",
    "    if random_flip_left_right:\n",
    "        images = tf.image.random_flip_left_right(images)\n",
    "    if random_brightness:\n",
    "        images = tf.image.random_brightness(images, max_delta=0.2)\n",
    "    if random_contrast:\n",
    "        images = tf.image.random_contrast(images, 0.9, 1.1)\n",
    "    if random_saturation:\n",
    "        images = tf.image.random_saturation(images, 0.3, 0.5)\n",
    "    if random_hue:\n",
    "        images = tf.image.random_hue(images,0.2)\n",
    "    new_size = tf.constant([28,28],dtype=tf.int32)\n",
    "    images = tf.image.resize_images(images, new_size)\n",
    "    return images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "filenames=['train0.tfrecords']\n",
    "next_element=dataset(filenames,batch_size=5,epochs=1)\n",
    "init = tf.global_variables_initializer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(28, 28, 1)\n",
      "(28, 28, 1)\n",
      "(28, 28, 1)\n",
      "(28, 28, 1)\n",
      "(28, 28, 1)\n"
     ]
    }
   ],
   "source": [
    "with tf.Session() as sess:\n",
    "    sess.run(init)\n",
    "    batch_images,batch_labels=sess.run([next_element[0],next_element[1]])\n",
    "    for i in range(batch_images.shape[0]):\n",
    "        print(batch_images[i].shape)\n",
    "        img = np.array(batch_images[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "### 二、 使用生成器的方法读取数据训练神经网络\n",
    "\n",
    "任务：使用mnist数据集来实现图像的分类：\n",
    "\n",
    "<img src=\"mnist.png\" style=\"zoom:50%\" >\n",
    "\n",
    "输入是以下的一张图片：\n",
    "\n",
    "<img src=\"0.png\" style=\"zoom:100%\" >\n",
    "\n",
    "等价于一个矩阵：\n",
    "\n",
    "<img src=\"0_matrix.png\" style=\"zoom:15%\" >"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-31-ac7c8520d8a4>:14: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "WARNING:tensorflow:From /Users/wanjun/anaconda/envs/python36/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please write your own downloading logic.\n",
      "WARNING:tensorflow:From /Users/wanjun/anaconda/envs/python36/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
      "WARNING:tensorflow:From /Users/wanjun/anaconda/envs/python36/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From /Users/wanjun/anaconda/envs/python36/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.one_hot on tensors.\n",
      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From /Users/wanjun/anaconda/envs/python36/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "WARNING:tensorflow:From /Users/wanjun/anaconda/envs/python36/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n",
      "WARNING:tensorflow:From /Users/wanjun/anaconda/envs/python36/lib/python3.6/site-packages/tensorflow/python/ops/losses/losses_impl.py:209: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n",
      "epoch:1,train_loss:0.1908,test_acc:0.9375\n",
      "epoch:2,train_loss:0.1843,test_acc:0.9551\n",
      "epoch:3,train_loss:0.1266,test_acc:0.9607\n",
      "epoch:4,train_loss:0.1294,test_acc:0.9622\n",
      "epoch:5,train_loss:0.1487,test_acc:0.9657\n",
      "epoch:6,train_loss:0.0421,test_acc:0.9703\n",
      "epoch:7,train_loss:0.0970,test_acc:0.9697\n",
      "epoch:8,train_loss:0.0249,test_acc:0.9730\n",
      "epoch:9,train_loss:0.0297,test_acc:0.9742\n",
      "epoch:10,train_loss:0.0278,test_acc:0.9735\n",
      "epoch:11,train_loss:0.0328,test_acc:0.9720\n",
      "epoch:12,train_loss:0.0095,test_acc:0.9760\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "import sys\n",
    "\n",
    "\n",
    "tf.reset_default_graph()\n",
    "epochs = 15\n",
    "batch_size = 100\n",
    "total_sum = 0\n",
    "epoch = 0\n",
    "filenames=['train_{}.tfrecords'.format(i) for i in range(10)]\n",
    "next_element=dataset(filenames,batch_size=batch_size,epochs=epochs)\n",
    "\n",
    "mnist = input_data.read_data_sets('MNIST_data', one_hot=True)\n",
    "train_num = mnist.train.num_examples\n",
    "\n",
    "\n",
    "\n",
    "input_data = tf.placeholder(tf.float32,shape=(None,784))\n",
    "input_label = tf.placeholder(tf.float32,shape=(None,10))\n",
    "\n",
    "w1 = tf.get_variable(shape=(784,64),name='hidden_1_w')\n",
    "b1 = tf.get_variable(shape=(64),initializer=tf.zeros_initializer(),name='hidden_1_b')\n",
    "\n",
    "w2 = tf.get_variable(shape=(64,32),name='hidden_2_w')\n",
    "b2 = tf.get_variable(shape=(32),initializer=tf.zeros_initializer(),name='hidden_2_b')\n",
    "\n",
    "w3 = tf.get_variable(shape=(32,10),name='layer_output')\n",
    "\n",
    "#logit层\n",
    "output = tf.matmul(tf.nn.relu(tf.matmul(tf.nn.relu(tf.matmul(input_data,w1)+b1),w2)+b2),w3)\n",
    "\n",
    "loss = tf.losses.softmax_cross_entropy(input_label,output)\n",
    "\n",
    "#opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)\n",
    "opt = tf.train.AdamOptimizer()\n",
    "\n",
    "train_op = opt.minimize(loss)\n",
    "\n",
    "# 测试评估\n",
    "correct_pred = tf.equal(tf.argmax(input_label,axis=1),tf.argmax(output,axis=1))\n",
    "acc = tf.reduce_mean(tf.cast(correct_pred,tf.float32))\n",
    "\n",
    "tf.add_to_collection('my_op',input_data)\n",
    "tf.add_to_collection('my_op',output)\n",
    "tf.add_to_collection('my_op',loss)\n",
    "\n",
    "init = tf.global_variables_initializer()\n",
    "saver = tf.train.Saver()\n",
    "with tf.Session() as sess:\n",
    "    sess.run([init])\n",
    "    test_data = mnist.test.images\n",
    "    test_label = mnist.test.labels\n",
    "    while epoch<epochs:\n",
    "        data,label=sess.run([next_element[0],next_element[1]])\n",
    "        data = data.reshape(-1,784)\n",
    "        total_sum+=batch_size\n",
    "        sess.run([train_op],feed_dict={input_data:data,input_label:label})\n",
    "        if total_sum//train_num>epoch:\n",
    "            epoch = total_sum//train_num\n",
    "            loss_val = sess.run([loss],feed_dict={input_data:data,input_label:label})\n",
    "            acc_test = sess.run([acc],feed_dict={input_data:test_data,input_label:test_label})\n",
    "            saver.save(sess, save_path=\"./model/my_model.ckpt\")\n",
    "            print('epoch:{},train_loss:{:.4f},test_acc:{:.4f}'.format(epoch,loss_val[0],acc_test[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x1a4edb8978>"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADOhJREFUeJzt3W+IXPW9x/HPx70NaFpITDANaa72Frl6yYO0LnrBi1rUkiuBGLSSKCUXStMHFSxEqMQHDUJBSvonjwpbXBqhtS2kvQkYtUsQbOEiRg3RNjaVujZr1qQhSg0iUfO9D/ZEtnHnN+PMmTmz+32/IOzM+Z4z58uQz5wzc/78HBECkM9FTTcAoBmEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUv8yyJXZ5nRCoM8iwp3M19OW3/Y623+2/artB3p5LQCD5W7P7bc9IumopFslTUl6TtLmiPhTYRm2/ECfDWLLf62kVyPirxFxVtIvJW3o4fUADFAv4V8l6dis51PVtH9ie6vtg7YP9rAuADXr5Qe/uXYtPrZbHxFjksYkdvuBYdLLln9K0upZzz8n6Xhv7QAYlF7C/5ykK21/3vYiSZsk7aunLQD91vVuf0R8YPteSU9JGpE0HhF/rK0zAH3V9aG+rlbGd36g7wZykg+A+YvwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpLoeoluSbE9KekfSh5I+iIjROpoC0H89hb/y5Yg4VcPrABggdvuBpHoNf0j6ne3nbW+toyEAg9Hrbv/1EXHc9mWSJmy/EhHPzJ6h+lDggwEYMo6Iel7I3iHpTETsLMxTz8oAtBQR7mS+rnf7bS+2/ZnzjyV9RdLL3b4egMHqZbd/haTf2j7/Or+IiCdr6QpA39W229/RytjtB/qu77v9AOY3wg8kRfiBpAg/kBThB5Ii/EBSdVzVhyG2aNGiYv3mm28u1u++++5ifdmyZcX6unXrivVePPlk+bSS2267rW/rXgjY8gNJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUhznXwCuuuqqlrVdu3YVl73llluK9ep+DS21uyT80KFDLWtLliwpLnv55Zf3tG6UseUHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQ4zj8EtmzZUqxfc801xfrmzZtb1i66qPz5PjExUazv2bOnWH/66aeL9TNnzrSsHThwoLhsO4cPH+5p+ezY8gNJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUm2H6LY9Lmm9pJMRsaaadqmkX0m6QtKkpLsi4q22K0s6RHe7e9fv37+/WH///feL9SeeeKJl7b777isu+/rrrxfrvVq/fn3L2t69e3t67RUrVhTrp06d6un156s6h+j+maQL//c+IOlARFwp6UD1HMA80jb8EfGMpNMXTN4gaXf1eLek22vuC0Cfdfudf0VETEtS9fey+loCMAh9P7ff9lZJW/u9HgCfTLdb/hO2V0pS9fdkqxkjYiwiRiNitMt1AeiDbsO/T9L5S9G2SOrtZ1sAA9c2/LYfk/R/kv7d9pTtr0t6WNKttv8i6dbqOYB5pO13/ohodbF4eWB3fOSOO+7oafmdO3cW6w8++GBPr99P999/f9fLvvbaa8V61uP4deEMPyApwg8kRfiBpAg/kBThB5Ii/EBS3Lp7AKanp4v18fHxYv2hhx6qs51aXXfddcX6DTfc0LJ29uzZ4rKbNm3qqid0hi0/kBThB5Ii/EBShB9IivADSRF+ICnCDyTV9tbdta4s6a27F7J25yiUhh8/evRocdmrr766q56yq/PW3QAWIMIPJEX4gaQIP5AU4QeSIvxAUoQfSIrj/Ci68cYbi/WJiYli/e23325ZK13rL0mvvPJKsY65cZwfQBHhB5Ii/EBShB9IivADSRF+ICnCDyTV9r79tsclrZd0MiLWVNN2SPqGpL9Xs22PiP39ahLNWbp0abE+MjJSrL/55pstaxzHb1YnW/6fSVo3x/QfRcTa6h/BB+aZtuGPiGcknR5ALwAGqJfv/PfaPmx73HZ53xDA0Ok2/D+R9AVJayVNS/pBqxltb7V90PbBLtcFoA+6Cn9EnIiIDyPinKSfSrq2MO9YRIxGxGi3TQKoX1fht71y1tONkl6upx0Ag9LJob7HJN0kabntKUnflXST7bWSQtKkpG/2sUcAfdA2/BGxeY7Jj/ShFwyhbdu29bT8nj17auoEdeMMPyApwg8kRfiBpAg/kBThB5Ii/EBS3Lo7uVWrVhXrU1NTxfpbb71VrK9Zs6Zl7fjx48Vl0R1u3Q2giPADSRF+ICnCDyRF+IGkCD+QFOEHkmp7SS8WtnaX7LY7D2Tjxo3FOsfyhxdbfiApwg8kRfiBpAg/kBThB5Ii/EBShB9Iiuv5F7glS5YU68eOHSvWFy9eXKwvX768WD99mjFeB43r+QEUEX4gKcIPJEX4gaQIP5AU4QeSIvxAUm2v57e9WtKjkj4r6ZyksYjYZftSSb+SdIWkSUl3RUT5Ju4YuIsvvrhYv+SSS4r1ycnJYv299977pC1hSHSy5f9A0raIuFrSf0r6lu3/kPSApAMRcaWkA9VzAPNE2/BHxHREvFA9fkfSEUmrJG2QtLuabbek2/vVJID6faLv/LavkPRFSc9KWhER09LMB4Sky+puDkD/dHwPP9uflrRH0rcj4h92R6cPy/ZWSVu7aw9Av3S05bf9Kc0E/+cR8Ztq8gnbK6v6Skkn51o2IsYiYjQiRutoGEA92obfM5v4RyQdiYgfzirtk7SlerxF0t762wPQL53s9l8v6WuSXrJ9qJq2XdLDkn5t++uS/ibpq/1pEb248847e1r+8ccfL9bffffdnl4fzWkb/oj4g6RWX/BvrrcdAIPCGX5AUoQfSIrwA0kRfiApwg8kRfiBpBiie4G75557ivV2p2k/++yzdbaDIcKWH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeS4jj/AtduCPZ29VOnTtXZDoYIW34gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSMrtjvPWujJ7cCtLZNmyZS1rL774YnHZVatWFesjIyNd9YTmRERHY+mx5QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpNpez297taRHJX1W0jlJYxGxy/YOSd+Q9Pdq1u0Rsb9fjaK1pUuXtqy1O46PvDq5mccHkrZFxAu2PyPpedsTVe1HEbGzf+0B6Je24Y+IaUnT1eN3bB+RxOYEmOc+0Xd+21dI+qKk82M43Wv7sO1x23Pue9reavug7YM9dQqgVh2H3/anJe2R9O2I+Iekn0j6gqS1mtkz+MFcy0XEWESMRsRoDf0CqElH4bf9Kc0E/+cR8RtJiogTEfFhRJyT9FNJ1/avTQB1axt+zwzj+oikIxHxw1nTV86abaOkl+tvD0C/dPJr//WSvibpJduHqmnbJW22vVZSSJqU9M2+dIi2pqenW9aeeuqp4rJvvPFG3e1gnujk1/4/SJrr+mCO6QPzGGf4AUkRfiApwg8kRfiBpAg/kBThB5Li1t3AAsOtuwEUEX4gKcIPJEX4gaQIP5AU4QeSIvxAUp1cz1+nU5Jen/V8eTVtGA1rb8Pal0Rv3aqzt8s7nXGgJ/l8bOX2wWG9t9+w9jasfUn01q2memO3H0iK8ANJNR3+sYbXXzKsvQ1rXxK9dauR3hr9zg+gOU1v+QE0pJHw215n+8+2X7X9QBM9tGJ70vZLtg81PcRYNQzaSdsvz5p2qe0J23+p/rYeonfwve2w/Ub13h2yfVtDva22/bTtI7b/aPu+anqj712hr0bet4Hv9tsekXRU0q2SpiQ9J2lzRPxpoI20YHtS0mhENH5M2PYNks5IejQi1lTTvi/pdEQ8XH1wLo2I7wxJbzsknWl65OZqQJmVs0eWlnS7pP9Rg+9doa+71MD71sSW/1pJr0bEXyPirKRfStrQQB9DLyKekXT6gskbJO2uHu/WzH+egWvR21CIiOmIeKF6/I6k8yNLN/reFfpqRBPhXyXp2KznUxquIb9D0u9sP297a9PNzGFFNWz6+eHTL2u4nwu1Hbl5kC4YWXpo3rtuRryuWxPhn+sWQ8N0yOH6iPiSpP+W9K1q9xad6Wjk5kGZY2TpodDtiNd1ayL8U5JWz3r+OUnHG+hjThFxvPp7UtJvNXyjD584P0hq9fdkw/18ZJhGbp5rZGkNwXs3TCNeNxH+5yRdafvzthdJ2iRpXwN9fIztxdUPMbK9WNJXNHyjD++TtKV6vEXS3gZ7+SfDMnJzq5Gl1fB7N2wjXjdykk91KOPHkkYkjUfE9wbexBxs/5tmtvbSzBWPv2iyN9uPSbpJM1d9nZD0XUn/K+nXkv5V0t8kfTUiBv7DW4vebtLMrutHIzef/4494N7+S9LvJb0k6Vw1ebtmvl839t4V+tqsBt43zvADkuIMPyApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSf0/P4DMdhH7QBsAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "%matplotlib inline\n",
    "index = 666\n",
    "plt.imshow(test_data[index].reshape(28,28),cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from ./model/my_model.ckpt\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "saver = tf.train.import_meta_graph('./model/my_model.ckpt.meta')\n",
    "saver.restore(sess,\"./model/my_model.ckpt\")\n",
    "input_tensor = tf.get_collection('my_op')[0]\n",
    "output_tensor = tf.get_collection('my_op')[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "### 三、 Tensorflow和Keras的混合使用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "epoch:1,train_loss:0.0848,test_acc:0.9731\n",
      "epoch:2,train_loss:0.0854,test_acc:0.9833\n",
      "epoch:3,train_loss:0.0118,test_acc:0.9840\n",
      "epoch:4,train_loss:0.0443,test_acc:0.9830\n",
      "epoch:5,train_loss:0.0137,test_acc:0.9853\n",
      "epoch:6,train_loss:0.0154,test_acc:0.9876\n",
      "epoch:7,train_loss:0.0015,test_acc:0.9875\n",
      "epoch:8,train_loss:0.0066,test_acc:0.9867\n",
      "epoch:9,train_loss:0.0019,test_acc:0.9906\n",
      "epoch:10,train_loss:0.0048,test_acc:0.9890\n",
      "epoch:11,train_loss:0.0003,test_acc:0.9897\n",
      "epoch:12,train_loss:0.0007,test_acc:0.9890\n",
      "epoch:13,train_loss:0.0037,test_acc:0.9894\n",
      "epoch:14,train_loss:0.0017,test_acc:0.9882\n",
      "epoch:15,train_loss:0.0006,test_acc:0.9870\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "import sys\n",
    "\n",
    "\n",
    "tf.reset_default_graph()\n",
    "epochs = 15\n",
    "batch_size = 100\n",
    "total_sum = 0\n",
    "epoch = 0\n",
    "filenames=['train_{}.tfrecords'.format(i) for i in range(10)]\n",
    "next_element=dataset(filenames,batch_size=batch_size,epochs=epochs)\n",
    "\n",
    "mnist = input_data.read_data_sets('MNIST_data', one_hot=True)\n",
    "train_num = mnist.train.num_examples\n",
    "\n",
    "input_data = tf.placeholder(tf.float32,shape=(None,28,28,1))\n",
    "input_label = tf.placeholder(tf.float32,shape=(None,10))\n",
    "\n",
    "hidden1 = tf.keras.layers.Conv2D(filters=16,kernel_size=3,strides=1,padding='valid',activation='relu')(input_data)\n",
    "\n",
    "hidden2 = tf.keras.layers.Conv2D(filters=32,kernel_size=3,strides=2,padding='same',activation='relu')(hidden1)\n",
    "\n",
    "hidden3 = tf.keras.layers.Conv2D(filters=64,kernel_size=3,strides=2,padding='valid',activation='relu')(hidden2)\n",
    "\n",
    "hidden4 = tf.keras.layers.MaxPool2D(pool_size=2)(hidden3)\n",
    "\n",
    "hidden5 = tf.keras.layers.Conv2D(filters=128,kernel_size=3,strides=2,padding='valid',activation='relu')(hidden4)\n",
    "\n",
    "hidden5 = tf.layers.Flatten()(hidden5)\n",
    "\n",
    "output = tf.keras.layers.Dense(10,activation='softmax')(hidden5)\n",
    "\n",
    "#损失函数\n",
    "loss = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(input_label,output))\n",
    "\n",
    "#优化器\n",
    "opt = tf.train.AdamOptimizer()\n",
    "train_op = opt.minimize(loss)\n",
    "\n",
    "# 测试评估\n",
    "acc = tf.reduce_mean(tf.keras.metrics.categorical_accuracy(input_label,output))\n",
    "#correct_pred = tf.equal(tf.argmax(input_label,axis=1),tf.argmax(output,axis=1))\n",
    "#acc = tf.reduce_mean(tf.cast(correct_pred,tf.float32))\n",
    "\n",
    "tf.add_to_collection('my_op',input_data)\n",
    "tf.add_to_collection('my_op',output)\n",
    "tf.add_to_collection('my_op',loss)\n",
    "\n",
    "init = tf.global_variables_initializer()\n",
    "saver = tf.train.Saver()\n",
    "with tf.Session() as sess:\n",
    "    sess.run([init])\n",
    "    test_data = mnist.test.images\n",
    "    test_label = mnist.test.labels\n",
    "    test_data = test_data.reshape(-1,28,28,1)\n",
    "    while epoch<epochs:\n",
    "        data,label=sess.run([next_element[0],next_element[1]])\n",
    "        total_sum+=batch_size\n",
    "        sess.run([train_op],feed_dict={input_data:data,input_label:label})\n",
    "        if total_sum//train_num>epoch:\n",
    "            epoch = total_sum//train_num\n",
    "            loss_val = sess.run([loss],feed_dict={input_data:data,input_label:label})\n",
    "            acc_test = sess.run([acc],feed_dict={input_data:test_data,input_label:test_label})\n",
    "            saver.save(sess, save_path=\"./model/my_model.ckpt\")\n",
    "            print('epoch:{},train_loss:{:.4f},test_acc:{:.4f}'.format(epoch,loss_val[0],acc_test[0]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
